import torch
from nltk import Tree
import numpy as np

def my_index_put(data, ixs, value):
    """Write ``value`` into ``data`` at the coordinates listed in ``ixs`` (in place).

    The last dimension of ``ixs`` holds coordinates: ``ixs[..., j]`` indexes
    dimension ``j`` of ``data``. With ``k = ixs.size(-1)``, every addressed
    position ``data[ixs[..., 0], ..., ixs[..., k-1]]`` (including any trailing
    dimensions of ``data`` beyond ``k``) is assigned ``value``.

    Args:
        data: Tensor modified in place.
        ixs: Integer index tensor whose last dimension enumerates coordinates.
        value: Scalar or tensor broadcastable to the indexed region.

    Returns:
        None. ``ixs`` itself is not modified.
    """
    # unbind(-1) splits the coordinate axis into one index tensor per data
    # dimension, each shaped ixs.shape[:-1]. This avoids squeezing the views
    # returned by split() in place.
    data[ixs.unbind(-1)] = value


def my_index_select(data, ixs):
    """Gather sub-tensors of ``data`` at the coordinates listed in ``ixs``.

    The last dimension of ``ixs`` holds coordinates: ``ixs[..., j]`` indexes
    dimension ``j`` of ``data``. With ``k = ixs.size(-1)`` the result has shape
    ``ixs.shape[:-1] + data.shape[k:]``, and
    ``result[p] == data[ixs[p][0], ..., ixs[p][k-1]]``.

    Args:
        data: Tensor to gather from.
        ixs: Integer index tensor whose last dimension enumerates coordinates.

    Returns:
        The gathered tensor. If ``ixs`` has more coordinates than ``data`` has
        dimensions, prints an error message and returns None.

    Raises:
        IndexError: If any coordinate is out of range for its dimension.
            Previously, out-of-range coordinates in dimensions after the
            first were silently folded into a neighbouring row.
    """
    if len(data.size()) < ixs.size(-1):
        print("ix select error...")
        return None
    # Advanced indexing with one index tensor per leading dimension places the
    # broadcast index shape first and keeps the remaining data dims, giving
    # exactly ixs.shape[:-1] + data.shape[k:]. Each dimension is bounds-checked,
    # and zero-sized dimensions need no reshape(-1) inference.
    return data[ixs.unbind(-1)]


def my_gen_clique_spans(span_list):
    """Pair the longest span ending at each interior boundary with the longest one starting there.

    Args:
        span_list: Iterable of ``(left, right, label)`` triples.

    Returns:
        A list with one entry per boundary ``b`` in ``range(1, max_right)``
        (``max_right`` being the largest ``right`` seen) at which some span
        ends and some span starts. Each entry is
        ``(end_left, b, start_right, end_label, start_label)``, built from the
        longest span ending at ``b`` and the longest span starting at ``b``.
        Among spans of equal length, the first one in ``span_list`` wins.
    """
    starting_at = {}
    ending_at = {}
    rightmost = -1
    for left, right, label in span_list:
        rightmost = max(rightmost, right)
        triple = (left, right, label)
        starting_at.setdefault(left, []).append(triple)
        ending_at.setdefault(right, []).append(triple)

    def _longest(spans):
        # First span with the strictly greatest positive length, else None.
        best, best_len = None, 0
        for left, right, label in spans:
            if right - left > best_len:
                best_len = right - left
                best = (left, right, label)
        return best

    cliques = []
    for boundary in range(1, rightmost):
        starters = starting_at.get(boundary)
        enders = ending_at.get(boundary)
        if starters is None or enders is None:
            continue
        head = _longest(starters)
        tail = _longest(enders)
        cliques.append((tail[0], head[0], head[1], tail[2], head[2]))
    return cliques


class MyEvaluation:
    """Corpus-level precision/recall/F1 of predicted trees against gold trees.

    Four metrics are scored from the label sets produced by ``analyze_tree``:
    labelled pre-terminals ("pos"), word spans ("seg"), labelled constituents
    ("cst") and constituent spans that do not coincide with a word span
    ("spn"). Counts are summed over all sentence pairs before P/R/F is
    computed.
    """

    @staticmethod
    def evaluate(gold_str_list: list, pred_str_list: list):
        """Parse bracketed tree strings and score them with ``evaluate_trees``."""
        assert len(gold_str_list) == len(pred_str_list)

        return MyEvaluation.evaluate_trees(
            [Tree.fromstring(text) for text in gold_str_list],
            [Tree.fromstring(text) for text in pred_str_list],
        )

    @staticmethod
    def evaluate_trees(gold_trees: list, pred_trees: list):
        """Score parallel lists of trees.

        Returns:
            Four ``(precision, recall, f1)`` tuples, in the order
            pos, seg, cst, spn.
        """
        assert len(gold_trees) == len(pred_trees)

        # One running [right, gold, pred] count per metric, in output order.
        totals = [np.array([0, 0, 0]) for _ in range(4)]
        for gold_tree, pred_tree in zip(gold_trees, pred_trees):
            paired = zip(analyze_tree(gold_tree), analyze_tree(pred_tree))
            for slot, (gold_items, pred_items) in enumerate(paired):
                totals[slot] += np.array(calc_confusion_num(gold_items, pred_items))

        return tuple(calc_prf(counts) for counts in totals)

def analyze_tree(tree):
    """Extract the four label collections used for evaluation from one tree.

    Spans are ``(start, end)`` character offsets as computed by
    ``_recursive_get_labels``.

    Returns:
        ``(pos_labels, seg_labels, cst_labels, spn_labels)`` where
        ``pos_labels`` are ``(pos, span)`` pairs for pre-terminals,
        ``seg_labels`` are just those word spans, ``cst_labels`` are
        ``(label, span)`` pairs for all non-pre-terminal nodes (root
        included), and ``spn_labels`` are the constituent spans that are not
        also a word span.
    """
    pos_labels, cst_labels = [], []
    _recursive_get_labels(node=tree,
                          i=0,
                          pos_labels=pos_labels,
                          cst_labels=cst_labels)

    seg_labels = [word_span for _, word_span in pos_labels]
    word_spans = set(seg_labels)
    spn_labels = [span for _, span in cst_labels if span not in word_spans]

    return pos_labels, seg_labels, cst_labels, spn_labels


def _recursive_get_labels(node: Tree, i, pos_labels: list, cst_labels: list):
    """Collect labelled character spans of ``node`` starting at offset ``i``.

    A node whose first child is a string is treated as a pre-terminal: its
    word ``node[0]`` covers ``len(word)`` characters, and ``(label, span)`` is
    appended to ``pos_labels``. Any other node is processed child by child,
    left to right, and then ``(label, span)`` for the whole node is appended
    to ``cst_labels`` (so children are recorded before their parent).

    Returns:
        The character offset just past the end of ``node``.
    """
    if not isinstance(node[0], str):
        end = i
        for child in node:
            end = _recursive_get_labels(node=child,
                                        i=end,
                                        pos_labels=pos_labels,
                                        cst_labels=cst_labels)
        cst_labels.append((node.label(), (i, end)))
        return end

    end = i + len(node[0])
    pos_labels.append((node.label(), (i, end)))
    return end


def calc_confusion_num(golds, preds):
    """Return ``(num_right, num_gold, num_pred)`` over the distinct items.

    Duplicates within ``golds`` or ``preds`` are counted once each.
    """
    unique_gold, unique_pred = set(golds), set(preds)
    return len(unique_gold & unique_pred), len(unique_gold), len(unique_pred)


def calc_prf(nums):
    """Compute precision, recall and F1 from an integer ``[right, gold, pred]`` array.

    Each ratio whose denominator is zero is reported as ``0``.
    """
    assert np.issubdtype(nums.dtype, np.integer)
    n_right, n_gold, n_pred = nums

    if n_pred > 0:
        precision = n_right / n_pred
    else:
        precision = 0

    if n_gold > 0:
        recall = n_right / n_gold
    else:
        recall = 0

    denom = precision + recall
    if denom > 0:
        f1 = 2 * precision * recall / denom
    else:
        f1 = 0

    return precision, recall, f1


def my_convert_char_forward(str_input):
    """Expand each pre-terminal's word into one ``(char c)`` node per character.

    A pre-terminal (a node whose children are all plain strings), e.g.
    ``(NN abc)``, becomes ``(NN (char a) (char b) (char c))``.

    Args:
        str_input: A tree in bracketed string form.

    Returns:
        The converted tree rendered on one line by ``my_oneline``, or ``""``
        when ``str_input`` is empty.
    """
    if not str_input:
        return ""
    root = Tree.fromstring(str_input)
    # The post-order list is built before any mutation, so the (char c)
    # nodes created below are never revisited.
    for node in my_post_order(root):
        if not my_is_pre_leaf(node):
            continue
        # Take the characters of every string child. Pre-terminals normally
        # hold a single word, but extra words must not be left behind.
        chars = [ch for word in node for ch in word]
        node.clear()
        for ch in chars:
            # Build the node directly rather than parsing "(char " + ch + ")":
            # parsing breaks on "(" / ")" and drops whitespace characters.
            node.append(Tree("char", [ch]))
    return my_oneline(root)


def my_convert_char_backward(str_input):
    """Collapse character-level nodes back into word leaves.

    Intended to undo ``my_convert_char_forward``. Any node with at least one
    pre-terminal child, e.g. ``(NN (char a) (char b))``, has all its children
    replaced by the concatenation of its leaves, giving ``(NN ab)``. Other
    nodes are processed recursively. Nodes that already hold plain string
    leaves are left unchanged.

    Args:
        str_input: A tree in bracketed string form.

    Returns:
        The converted tree rendered on one line by ``my_oneline``, or ``""``
        when ``str_input`` is empty.
    """
    if not str_input:
        return ""

    root = Tree.fromstring(str_input)

    def _my_process(node: Tree):
        if any(my_is_pre_leaf(child) for child in node):
            leaf_str = "".join(node.leaves())
            node.clear()
            node.append(leaf_str)
            return
        for child in node:
            # Descend only into subtrees. Iterating a one-character string
            # yields the string itself, so recursing into str leaves (e.g. for
            # input "(NN ab)") would never terminate.
            if not my_is_leaf(child):
                _my_process(child)

    _my_process(root)
    return my_oneline(root)


def my_is_leaf(node):
    """Return True for a terminal, i.e. anything that is not an nltk ``Tree`` (such as a word string)."""
    if isinstance(node, Tree):
        return False
    return True


def my_is_pre_leaf(node):
    """Return True if ``node`` is a ``Tree`` whose children are all terminals.

    A ``Tree`` with no children also counts as a pre-leaf.
    """
    if my_is_leaf(node):
        return False
    for child in node:
        if not my_is_leaf(child):
            return False
    return True


def my_post_order(node):
    """List every ``Tree`` node under ``node`` (inclusive) in post-order.

    Children are visited left to right before their parent. Terminal
    (non-``Tree``) nodes are omitted, so a terminal ``node`` yields ``[]``.
    """
    def _walk(subtree):
        for child in subtree:
            if not my_is_leaf(child):
                yield from _walk(child)
        yield subtree

    if my_is_leaf(node):
        return []
    return list(_walk(node))


def my_oneline(t):
    """Render tree ``t`` as a single-line bracketed string.

    Uses round parentheses, no separator after node labels, and unquoted
    leaves, e.g. ``(S (NN ab) (VV c))``.
    """
    # Tree.pformat returns its flat (one-line) rendering whenever that fits
    # within `margin`. An unbounded margin therefore always gives the
    # one-line form without calling the private Tree._pformat_flat.
    return t.pformat(margin=float("inf"), nodesep='', parens='()', quotes=False)


